import base64
import datetime
import json
import os
import re
import sys
import time
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Union

import cv2
import numpy as np
import requests
import yaml
from loguru import logger as eval_logger
from openai import AzureOpenAI, OpenAI

from lmms_eval.tasks._task_utils.file_utils import generate_submission_file

# Root of the Hugging Face cache: $HF_HOME if set, otherwise the default
# "~/.cache/huggingface"; "~" is expanded to the user's home directory.
hf_home = os.environ.get("HF_HOME", "~/.cache/huggingface")
base_cache_dir = os.path.expanduser(hf_home)


# Read the task's YAML config to find the dataset cache sub-directory.
# Lines carrying the custom "!function" tag are dropped first because
# yaml.safe_load has no constructor for that tag and would fail on it.
with open(Path(__file__).parent / "mmvu_val.yaml", "r", encoding="utf-8") as f:
    raw_data_val = f.readlines()
safe_data_val = [line for line in raw_data_val if "!function" not in line]
cache_name_val = yaml.safe_load("".join(safe_data_val))["dataset_kwargs"]["cache_dir"]
# Absolute directory that video paths in the dataset are resolved against.
cache_dir_val = os.path.join(base_cache_dir, cache_name_val)


def mmvu_doc_to_visual_val(doc):
    """
    Resolve a document's video file inside the validation cache directory.

    Args:
        doc: dataset record; ``doc["video_path"]`` is joined onto ``cache_dir_val``.
    Returns:
        A one-element list holding the resolved video path.
    Exits the process via ``sys.exit`` (with an explanatory message) when the
    resolved file does not exist.
    """
    full_path = os.path.join(cache_dir_val, doc["video_path"])
    if not os.path.exists(full_path):
        sys.exit(f"video path:{full_path} does not exist, please check")
    return [full_path]


# Prompt templates rendered with str.format() by mmvu_doc_to_text /
# mmvu_doc_to_text_cot. Placeholders: {question}, plus {a}..{e} for the
# multiple-choice options A-E.

# Direct-answer prompt for multiple-choice questions.
multiple_choice_prompt = """
Question:{question}
A: {a}
B: {b}
C: {c}
D: {d}
E: {e}
Visual Information: processed video
Do not generate any intermediate reasoning process. Answer directly with the option letter from the
given choices.
"""

# Direct-answer prompt for open-ended questions.
open_ended_prompt = """
Question:{question}
Visual Information: processed video
Do not generate any intermediate reasoning process. Directly output the final answer.
"""

# Chain-of-thought prompt for multiple-choice questions; asks the model to end
# with "Therefore, the final answer is: $LETTER".
multiple_choice_prompt_cot = """
Question:{question}
A: {a}
B: {b}
C: {c}
D: {d}
E: {e}
Visual Information: processed video
Answer the given multiple-choice question step by step. Begin by explaining your reasoning process
clearly. Conclude by stating the final answer using the following format: "Therefore, the final answer
is: $LETTER" (without quotes), where $LETTER is one of the options. Think step by step before
answering.
"""

# Chain-of-thought prompt for open-ended questions.
# NOTE(review): the requested answer format below contains an odd number of
# double quotes ("Therefore, ... is: "Answer: $ANSWER"). Left untouched because
# changing prompt text changes what evaluated models see; confirm against the
# reference MMVU prompt before editing.
open_ended_prompt_cot = """
Question:{question}
Visual Information: processed video
Answer the given question step by step. Begin by explaining your reasoning process clearly. Conclude
by stating the final answer using the following format: "Therefore, the final answer is: "Answer:
$ANSWER" (without quotes), where $ANSWER is the final answer of the question. Think step by
step before answering.
"""


def mmvu_doc_to_text(doc, lmms_eval_specific_kwargs=None):
    """
    Build the direct-answer prompt for one MMVU item.

    Args:
        doc: dataset record; reads ``question_type``, ``question`` and, for
            multiple-choice items, ``choices`` (mapping with keys "A".."E").
        lmms_eval_specific_kwargs: accepted for the lmms-eval task interface;
            not used here.
    Returns:
        The formatted prompt string.
    """
    question = doc["question"]
    # Accept both "multiple_choice" and "multiple-choice": the module has used
    # both spellings, and a mismatch silently sent multiple-choice items down
    # the open-ended path. NOTE(review): confirm the dataset's actual spelling.
    if str(doc["question_type"]).replace("_", "-") == "multiple-choice":
        choices = doc["choices"]
        return multiple_choice_prompt.format(
            question=question,
            a=choices["A"],
            b=choices["B"],
            c=choices["C"],
            d=choices["D"],
            e=choices["E"],
        )
    return open_ended_prompt.format(question=question)


def mmvu_doc_to_text_cot(doc, lmms_eval_specific_kwargs=None):
    """
    Build the chain-of-thought prompt for one MMVU item.

    Args:
        doc: dataset record; reads ``question_type``, ``question`` and, for
            multiple-choice items, ``choices`` (mapping with keys "A".."E").
        lmms_eval_specific_kwargs: accepted for the lmms-eval task interface;
            not used here.
    Returns:
        The formatted prompt string.
    """
    question = doc["question"]
    # Accept both "multiple_choice" and "multiple-choice": the module has used
    # both spellings, and a mismatch silently sent multiple-choice items down
    # the open-ended path. NOTE(review): confirm the dataset's actual spelling.
    if str(doc["question_type"]).replace("_", "-") == "multiple-choice":
        choices = doc["choices"]
        return multiple_choice_prompt_cot.format(
            question=question,
            a=choices["A"],
            b=choices["B"],
            c=choices["C"],
            d=choices["D"],
            e=choices["E"],
        )
    return open_ended_prompt_cot.format(question=question)


# Judge prompts sent by gpt_parser. They are rendered with str.format(), so the
# literal braces of the requested reply object are escaped as "{{" / "}}".
# The judge's reply is decoded with json.loads and read for the keys
# "extracted answer" and "correct".

# Multiple-choice judge prompt. Placeholders: {question}, {a}..{e},
# {ground_truth}, {model_response}.
mcq_eval_prompt = """
[Instruction]
Evaluate whether the model's final answer is correct by comparing it to the ground-truth answer provided for the given question.
You should first extract the final answer from the model's response, and then compare the extracted
answer with the ground-truth answer to determine its accuracy. Output your response in the following
structured format:
{{
"extracted answer": // str value "A" "B" "C" "D" "E", should be a single character
"correct": // boolean value, true if the answer is correct, false otherwise
}}
[User]
Question:{question}
A: {a}
B: {b}
C: {c}
D: {d}
E: {e}
Ground Truth Answer: {ground_truth}
Model Response to the Question: {model_response}
"""

# Open-ended judge prompt. Placeholders: {question}, {ground_truth},
# {model_response}.
open_ended_eval_prompt = """
[Instruction]
Evaluate whether the model's final answer is correct by comparing it to the ground-truth answer
provided for the given question. You should first extract the final answer from the model's response,
and then compare the extracted answer with the ground-truth answer to determine its accuracy. The
final answer generated by the model does not need to match the ground-truth answer word-for-word.
However, it should only be considered correct if it demonstrates the exact same technique or concept
explicitly and unambiguously equivalent to the ground-truth answer. Output your response in the
following structured format:
{{
"extracted answer": // str value, the short final answer extracted from the model's response, do not
hallucinate one that is not present in the response
"correct": // boolean value, true if the answer is correct, false otherwise
}}
[User]
Question:{question}
Ground Truth Answer: {ground_truth}
Model Response to the Question: {model_response}
"""

# Retry budget (number of attempts) and delay in seconds for the judge calls
# made from mmvu_process_results.
MAX_ITER = 5
NUM_SECONDS_TO_SLEEP = 1
# Judge backend: "openai" or "azure" (default). The module-level ``client``
# built here is what gpt_parser sends requests through.
API_TYPE = os.getenv("API_TYPE", "azure")
if API_TYPE == "openai":
    endpoint = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
    deployment = os.getenv("DEPLOYMENT_NAME", "gpt-4o")
    subscription_key = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
    # ``openai.OpenAI`` takes ``base_url`` (it has no ``api_base`` or
    # ``api_version`` keyword; passing them raises TypeError) and appends
    # "/chat/completions" itself, so strip that route when OPENAI_API_URL
    # points at the full chat-completions endpoint.
    _chat_route = "/chat/completions"
    base_url = endpoint.rstrip("/")
    if base_url.endswith(_chat_route):
        base_url = base_url[: -len(_chat_route)]
    client = OpenAI(
        api_key=subscription_key,
        base_url=base_url,
    )
elif API_TYPE == "azure":
    endpoint = os.getenv("ENDPOINT_URL", "your_endpoint_url")
    deployment = os.getenv("DEPLOYMENT_NAME", "gpt-4o")
    subscription_key = os.getenv("AZURE_OPENAI_API_KEY", "your_api_key")
    client = AzureOpenAI(
        azure_endpoint=endpoint,
        api_key=subscription_key,
        api_version="2025-01-01-preview",
    )
else:
    raise ValueError(f"Unsupported API_TYPE: {API_TYPE}. Please set it to 'openai' or 'azure'.")


def gpt_parser(response, doc):
    """
    Ask the GPT judge to extract and grade the final answer in ``response``.

    Args:
        response: the evaluated model's raw output text.
        doc: dataset record; reads ``question_type``, ``question``, ``answer``
            and, for multiple-choice items, ``choices`` (keys "A".."E").
    Returns:
        The judge's reply decoded into a dict (expected keys "extracted answer"
        and "correct"), or ``None`` when the request fails or the reply is not
        a JSON object, so the caller can retry.
    """
    # Accept both "multiple-choice" and "multiple_choice" so the judge prompt
    # matches the prompt that was used to ask the question.
    # NOTE(review): confirm the dataset's actual spelling.
    question_type = str(doc["question_type"]).replace("_", "-")
    if question_type == "multiple-choice":
        choices = doc["choices"]
        answer = doc["answer"]
        prompt = mcq_eval_prompt.format(
            question=doc["question"],
            a=choices["A"],
            b=choices["B"],
            c=choices["C"],
            d=choices["D"],
            e=choices["E"],
            # Give the judge both the letter and the option text.
            ground_truth=answer + " " + choices[answer],
            model_response=response,
        )
    else:
        prompt = open_ended_eval_prompt.format(question=doc["question"], ground_truth=doc["answer"], model_response=response)

    params = {
        # ``deployment`` comes from DEPLOYMENT_NAME (default "gpt-4o"). With
        # AzureOpenAI the ``model`` argument selects the deployment, so it must
        # not be hard-coded.
        "model": deployment,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 512,
        "temperature": 0.0,
    }

    try:
        completion = client.chat.completions.create(**params)
        response_text = completion.choices[0].message.content
    except Exception as e:  # network / auth / rate-limit errors: caller retries
        eval_logger.error(f"GPT judge request failed: {e}")
        return None
    eval_logger.debug(f"Raw GPT response: {response_text}")

    # Judges often wrap the JSON object in Markdown fences or surround it with
    # prose; decode only the outermost {...} span when one is present.
    text = (response_text or "").strip()
    match = re.search(r"\{.*\}", text, re.DOTALL)
    if match:
        text = match.group(0)
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError as e:
        eval_logger.error(f"Error parsing GPT response: {e}")
        return None
    if not isinstance(parsed, dict):
        # Callers read keys with .get(); anything else counts as a failed parse.
        eval_logger.error(f"Error parsing GPT response: expected a JSON object, got {type(parsed).__name__}")
        return None
    return parsed


def extract_category(doc):
    """Return the second-to-last "/"-separated component of ``doc["video_path"]``."""
    # rsplit with maxsplit=2 yields the same [-2] element as a full split.
    return doc["video_path"].rsplit("/", 2)[-2]


def mmvu_process_results(doc, results):
    """
    Grade a single model prediction with the GPT judge.

    Args:
        doc: an instance of the eval dataset
        results: [pred], the model's raw generated text for ``doc``
    Returns:
        ``{"accuracy": data_dict}`` where ``data_dict`` holds ``question_id``,
        ``category``, ``pred_answer`` (judge-extracted answer, "N/A" if grading
        failed), ``answer`` (ground truth) and a boolean ``correct``.
    """
    pred = results[0]
    category = extract_category(doc)

    parsed_response = None
    for attempt in range(MAX_ITER):
        parsed_response = gpt_parser(pred, doc)
        if isinstance(parsed_response, dict):
            break
        # Back off only between failed attempts so successfully graded
        # samples are not delayed.
        if attempt + 1 < MAX_ITER:
            time.sleep(NUM_SECONDS_TO_SLEEP)
    if not isinstance(parsed_response, dict):
        # Judge unavailable or unparseable after all retries: count as wrong.
        parsed_response = {"extracted answer": "N/A", "correct": False}

    pred_ans = parsed_response.get("extracted answer", "N/A")
    correct = parsed_response.get("correct", False)
    if isinstance(correct, str):
        # Judge replies sometimes encode the flag as "true"/"false".
        correct = correct.strip().lower() == "true"
    else:
        correct = bool(correct)

    data_dict = {"question_id": doc["id"], "category": category, "pred_answer": pred_ans, "answer": doc["answer"], "correct": correct}

    return {"accuracy": data_dict}


def mmvu_aggregate_results_val(results):
    """
    Aggregate per-sample grades into an overall accuracy.

    Args:
        results: a list of the ``data_dict`` values returned by
            ``mmvu_process_results`` (each with ``category`` and ``correct``).
    Returns:
        Overall accuracy as a percentage in [0, 100] (0 when ``results`` is
        empty). Per-discipline accuracies are logged, not returned.
    """

    # Maps a result's ``category`` (the second-to-last component of the video
    # path, per extract_category) to its discipline.
    TASK_MAP = {
        "Biology": "Science",
        "Chemistry": "Science",
        "Modern_Physics": "Science",
        "Astronomy": "Science",
        "Geography": "Science",
        "Materials_Science": "Science",
        "Neurobiology": "Science",
        "Electromagnetism": "Science",
        "Thermodynamics": "Science",
        "Mechanics": "Science",
        "Civil_Engineering": "Engineering",
        "Electrical_Engineering": "Engineering",
        "Mechanical_Engineering": "Engineering",
        "Biomedical_Engineering": "Engineering",
        "Electronics_and_Communication": "Engineering",
        "Computer_Science": "Engineering",
        "Clinical_Medicine": "Healthcare",
        "Basic_Medicine": "Healthcare",
        "Preventive_Medicine": "Healthcare",
        "Pharmacy": "Healthcare",
        "Dentistry": "Healthcare",
        "Art": "Humanities_and_Social_Science",
        "Literature": "Humanities_and_Social_Science",
        "History": "Humanities_and_Social_Science",
        "Law": "Humanities_and_Social_Science",
        "Economics": "Humanities_and_Social_Science",
        "Management": "Humanities_and_Social_Science",
    }

    # Sorted so the per-discipline log lines come out in a stable order; the
    # iteration order of a set of strings can change between runs.
    TASK_TYPES = sorted(set(TASK_MAP.values()))

    category2score = {task_type: {"correct": 0, "answered": 0} for task_type in TASK_TYPES}
    total_correct = 0
    total_answered = 0
    unmapped = defaultdict(int)

    for result in results:
        correct = result.get("correct", False)
        if isinstance(correct, str):
            # Judge replies sometimes encode the flag as "true"/"false".
            correct = correct.strip().lower() == "true"
        correct = int(bool(correct))

        # Every graded sample counts toward the overall average, including
        # samples whose category is missing from TASK_MAP.
        total_answered += 1
        total_correct += correct

        category = result["category"]
        if category in TASK_MAP:
            task_type = TASK_MAP[category]
            category2score[task_type]["answered"] += 1
            category2score[task_type]["correct"] += correct
        else:
            unmapped[category] += 1

    if unmapped:
        eval_logger.warning(f"Categories not found in TASK_MAP (counted only in the average accuracy): {dict(unmapped)}")

    category_scores = {}
    for task_type in TASK_TYPES:
        answered = category2score[task_type]["answered"]
        category_scores[task_type] = 100 * category2score[task_type]["correct"] / answered if answered > 0 else 0

    accuracy = 100 * total_correct / total_answered if total_answered > 0 else 0
    eval_logger.info("=" * 50)
    eval_logger.info(f"Average Accuracy: {accuracy:.2f}%")
    eval_logger.info("Categorical accuracy: ")
    for key, value in category_scores.items():
        eval_logger.info(f"{key} accuracy: {value:.2f}%")
    eval_logger.info("=" * 50)
    return accuracy
